import argparse
import json

import yaml
from torch.utils.data import DataLoader

from utils.datasets import *
from utils.utils import *


def test(data,
         weights=None,
         batch_size=16,
         imgsz=640,
         conf_thres=0.001,
         iou_thres=0.6,  # for NMS
         save_json=False,
         single_cls=False,
         augment=False,
         model=None,
         dataloader=None,
         fast=False,
         verbose=False):
    """Evaluate a detection model and report P, R, mAP@.5 and mAP@.5:.95.

    The function runs in one of two modes:

    * ``model is None`` (stand-alone): the device is chosen from the global
      ``opt.device``, old ``test_batch*.jpg`` files in the working directory
      are removed, and the checkpoint at ``weights`` is loaded, fused and moved
      to the device.
    * ``model`` given (comment in the original: called by train.py): the model
      is evaluated on whatever device its parameters live on and the loss is
      accumulated as well. NOTE: on a non-CPU device the passed model is
      converted to FP16 in place via ``model.half()``.

    When ``dataloader`` is None a rectangular-inference dataloader is built
    here; that path reads the global ``opt.task`` to choose between
    ``data['test']`` and ``data['val']``.

    Args:
        data: Path to a YAML file; its ``nc`` entry and ``val``/``test`` entry
            are read.
        weights: Checkpoint path, only used when ``model`` is None. The file is
            loaded with ``torch.load`` and its ``'model'`` entry is used.
        batch_size: Batch size for device selection and for the dataloader
            built here (capped at the dataset length).
        imgsz: Image size passed to the dataset and echoed in the timing tuple.
        conf_thres: Confidence threshold forwarded to NMS.
        iou_thres: IoU threshold forwarded to NMS.
        save_json: Collect detections into a COCO-style list, dump them to a
            JSON file and try to score them with pycocotools.
        single_cls: Treat the dataset as single-class (``nc`` becomes 1 and the
            dataset built here is created in single-class mode).
        augment: Forwarded to the model call as ``augment=``.
        model: Already constructed model, or None to load ``weights``.
        dataloader: Existing dataloader, or None to build one from ``data``.
        fast: Forwarded to NMS; forced on when the dataloader is built here and
            ``conf_thres > 0.001``.
        verbose: Also print one result line per class when ``nc > 1``.

    Returns:
        A 3-tuple ``(results, maps, t)`` where ``results`` is
        ``(mp, mr, map50, map, *losses)`` with the three accumulated loss
        components divided by ``len(dataloader)``, ``maps`` is a length-``nc``
        array of per-class AP (filled with the mean for classes without an
        entry), and ``t`` is ``(inference ms, NMS ms, total ms, imgsz, imgsz,
        batch_size)`` with the times averaged per image.
    """
    # Initialize/load model and set device
    if model is None:
        training = False
        device = torch_utils.select_device(opt.device, batch_size=batch_size)
        half = device.type != 'cpu'  # half precision only supported on CUDA

        # Remove plots left over from a previous run
        for f in glob.glob('test_batch*.jpg'):
            os.remove(f)

        # Load model
        google_utils.attempt_download(weights)
        model = torch.load(weights, map_location="cpu")['model'].float()  # load to FP32
        torch_utils.model_info(model)
        model.fuse()
        model.to(device)
        if half:
            model.half()  # to FP16

        if device.type != 'cpu' and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

    else:  # called by train.py
        training = True
        device = next(model.parameters()).device  # get model device
        half = device.type != 'cpu'  # half precision only supported on CUDA
        if half:
            model.half()  # to FP16 (in place on the caller's model)

    # Configure
    model.eval()
    with open(data) as f:
        # NOTE(review): FullLoader is not the safe loader; only point this at
        # trusted dataset files.
        data = yaml.load(f, Loader=yaml.FullLoader)  # dataset config dict
    nc = 1 if single_cls else int(data['nc'])  # number of classes
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    # iouv = iouv[0].view(1)  # comment for mAP@0.5:0.95
    niou = iouv.numel()

    # Dataloader
    if dataloader is None:  # not training
        img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
        if device.type != 'cpu':
            model(img.half() if half else img)  # run once (warm-up)

        fast |= conf_thres > 0.001  # enable fast mode
        # path to val/test images
        path = data['test'] if opt.task == 'test' else data['val']
        dataset = LoadImagesAndLabels(path,
                                      imgsz,
                                      batch_size,
                                      rect=True,  # rectangular inference
                                      single_cls=single_cls,  # honour the argument so it agrees with nc
                                      pad=0.5)  # padding
        batch_size = min(batch_size, len(dataset))
        # number of workers
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                num_workers=nw,
                                pin_memory=True,
                                collate_fn=dataset.collate_fn)

    seen = 0
    # DataParallel-wrapped models expose the class names on .module
    names = dict(enumerate(model.names if hasattr(model, 'names') else model.module.names))

    coco91class = coco80_to_coco91_class()
    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
    loss = torch.zeros(3, device=device)
    jdict, stats, ap, ap_class = [], [], [], []
    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
        img = img.to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        targets = targets.to(device)
        _, _, height, width = img.shape  # batch size, channels, height, width
        whwh = torch.Tensor([width, height, width, height]).to(device)

        # Disable gradients
        with torch.no_grad():
            # Run model
            t = torch_utils.time_synchronized()
            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
            t0 += torch_utils.time_synchronized() - t

            # Compute loss
            if training:  # if model has loss hyperparameters
                # GIoU, obj, cls
                loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]

            # Run NMS
            t = torch_utils.time_synchronized()
            output = non_max_suppression(inf_out,
                                         conf_thres=conf_thres,
                                         iou_thres=iou_thres,
                                         fast=fast)
            t1 += torch_utils.time_synchronized() - t

        # Statistics per image
        for si, pred in enumerate(output):
            labels = targets[targets[:, 0] == si, 1:]  # rows whose first column is this image index
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            seen += 1

            if pred is None:
                # No detections: record the targets so they count as misses
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool),
                                  torch.Tensor(),
                                  torch.Tensor(),
                                  tcls))
                continue

            # Append to text file
            # with open('test.txt', 'a') as file:
            #    [file.write('%11.5g' * 7 % tuple(x) + '\n') for x in pred]

            # Clip boxes to image bounds
            clip_coords(pred, (height, width))

            # Append to pycocotools JSON dictionary
            if save_json:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(Path(paths[si]).stem.split('_')[-1])
                box = pred[:, :4].clone()  # xyxy
                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                box = xyxy2xywh(box)  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                # `det` rather than `p`, which holds the precision values
                for det, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': image_id,
                                  'category_id': coco91class[int(det[5])],
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(det[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5]) * whwh

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero().view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero().view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction to target ious
                        ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices

                        # Append detections
                        for j in (ious > iouv[0]).nonzero():
                            d = ti[i[j]]  # detected target
                            if d not in detected:
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        # Plot images (first batch only)
        if batch_i < 1:
            f = 'test_batch%g_gt.jpg' % batch_i  # filename
            plot_images(img, targets, paths, f, names)  # ground truth
            f = 'test_batch%g_pred.jpg' % batch_i
            plot_images(img, output_to_target(output, width, height), paths, f, names)  # predictions

    # Compute statistics
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
    if len(stats):
        p, r, ap, f1, ap_class = ap_per_class(*stats)
        p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1)  # [P, R, AP@0.5, AP@0.5:0.95]
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

    # Print results per class
    if verbose and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds; max(seen, 1) avoids ZeroDivisionError when nothing was processed
    t = tuple(x / max(seen, 1) * 1E3 for x in (t0, t1, t0 + t1)) + \
        (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Save JSON
    if save_json and map50 and len(jdict):
        imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files]
        f = 'detections_val2017_%s_results.json' % \
            (weights.split(os.sep)[-1].replace('.pt', '') if weights else '')  # filename
        print('\nCOCO mAP with pycocotools... saving %s...' % f)
        with open(f, 'w') as file:
            json.dump(jdict, file)

        try:
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            # initialize COCO ground truth api
            cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0])
            cocoDt = cocoGt.loadRes(f)  # initialize COCO pred api

            cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
            cocoEval.params.imgIds = imgIds  # image IDs to evaluate
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            # update results (mAP@0.5:0.95, mAP@0.5)
            map, map50 = cocoEval.stats[:2]
        except Exception as e:  # Exception, not BaseException: Ctrl-C / SystemExit must propagate
            print('WARNING: pycocotools must be installed with numpy==1.17 to run correctly. '
                  'See https://github.com/cocodataset/cocoapi/issues/356')
            # Show the real cause too (e.g. missing annotations file or missing package)
            print('pycocotools evaluation failed: %s' % e)

    # Return results
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]
    return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t


if __name__ == '__main__':
    # Command-line entry point: parse options, then either run a single
    # evaluation ('val'/'test') or sweep image sizes over several checkpoints
    # ('study'). `opt` is left at module level because test() reads
    # opt.device and opt.task.
    parser = argparse.ArgumentParser(prog='test.py')
    parser.add_argument('--weights', type=str, default='', help='model.pt path')
    parser.add_argument('--data', type=str, default='data/voc.yaml', help='*.yaml path')
    parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    parser.add_argument('--save-json', action='store_true',
                        help='save a cocoapi-compatible JSON results file')
    # choices: an unrecognised task used to fall through both branches below
    # and exit without doing anything
    parser.add_argument('--task', default='val', choices=['val', 'test', 'study'],
                        help="'val', 'test', 'study'")
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    # NOTE(review): default=True combined with store_true means verbose is
    # always on and cannot be disabled from the command line; kept as is so
    # the printed output does not change.
    parser.add_argument('--verbose', default=True, action='store_true', help='report mAP by class')
    opt = parser.parse_args()
    opt.img_size = check_img_size(opt.img_size)
    # Data files named *coco.yaml always get the JSON export
    opt.save_json = opt.save_json or opt.data.endswith('coco.yaml')
    opt.data = check_file(opt.data)  # check file
    print(opt)

    # task = 'val', 'test', 'study'
    if opt.task in ['val', 'test']:  # (default) run normally
        test(opt.data,
             opt.weights,
             opt.batch_size,
             opt.img_size,
             opt.conf_thres,
             opt.iou_thres,
             opt.save_json,
             opt.single_cls,
             opt.augment,
             verbose=opt.verbose)

    elif opt.task == 'study':  # run over a range of settings and save/plot
        for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
            f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
            x = list(range(288, 896, 64))  # x axis
            y = []  # y axis
            for i in x:  # img-size
                print('\nRunning %s point %s...' % (f, i))
                # Pass --single-cls through so the class count agrees with the
                # dataset mode
                r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres,
                               opt.iou_thres, opt.save_json, opt.single_cls)
                y.append(r + t)  # results and times (tuple concatenation)
            np.savetxt(f, y, fmt='%10.4g')  # save
        os.system('zip -r study.zip study_*.txt')  # needs a `zip` executable on PATH
        # plot_study_txt(f, x)  # plot
